{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "np.random.seed()\n",
    "# tf.set_random_seed()\n",
    "\n",
    "'''\n",
    "CFO Equalizer is a net that, given a known preamble, will take in the convolved preamble and convolved data\n",
    "that have been passed through a multipath channel, and correct the CFO and equalize the data.\n",
    "For now we assume that the channel taps are known.\n",
    "'''\n",
    "\n",
    "class CFO_Equalizer():\n",
    "    \"\"\"Estimate and remove the CFO from received data, then equalize it.\n",
    "\n",
    "    TensorFlow 1.x graph, built in a private tf.Graph owned by the instance:\n",
    "      1. CFO estimation: a tanh MLP maps the concatenated original and received\n",
    "         preambles to one scalar per example, est_omega (radians per sample).\n",
    "      2. CFO correction: received data sample n is multiplied by exp(-1j * est_omega * n).\n",
    "      3. Equalization: for I and Q separately, a sigmoid MLP maps the corrected\n",
    "         component concatenated with the matching channel taps to the sent data.\n",
    "    The training loss is the MSE between the sent data and the estimate.\n",
    "    All I/Q tensors are (batch, length, 2) with the last axis holding [I, Q].\n",
    "    \"\"\"\n",
    "    def __init__(self, preamble_length = 50, data_length = 100, channel_length = 2,\n",
    "                 batch_size = 150, learning_rate = 0.01):\n",
    "        \"\"\"Build the graph, open a session on it and initialize its variables.\n",
    "\n",
    "        batch_size is only stored; the placeholders accept any batch size.\n",
    "        \"\"\"\n",
    "        # Network variables\n",
    "        self.preamble_length = preamble_length\n",
    "        self.data_length = data_length\n",
    "        self.channel_length = channel_length\n",
    "        self.batch_size = batch_size\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "        # A private graph keeps repeated constructions (e.g. re-running this cell)\n",
    "        # from piling models into the default graph, and makes the initializer\n",
    "        # below touch only this model's variables.\n",
    "        self.graph = tf.Graph()\n",
    "        with self.graph.as_default():\n",
    "            self._build_graph()\n",
    "            self.sess = tf.Session(graph=self.graph)\n",
    "            self.sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    def _build_graph(self):\n",
    "        \"\"\"Create the placeholders, the three network stages, the loss and the update op.\"\"\"\n",
    "        # Placeholders for training\n",
    "        self.preamble_original = tf.placeholder(tf.float32, [None, self.preamble_length, 2])\n",
    "        self.preamble_convolved = tf.placeholder(tf.float32, [None, self.preamble_length, 2])\n",
    "        self.data_original = tf.placeholder(tf.float32, [None, self.data_length, 2])\n",
    "        self.data_convolved = tf.placeholder(tf.float32, [None, self.data_length, 2])\n",
    "        self.channel = tf.placeholder(tf.float32, [None, self.channel_length, 2])\n",
    "\n",
    "        ###############\n",
    "        # CFO Estimation using preambles\n",
    "        ###############\n",
    "        # (batch, 2 * preamble_length, 2) -> (batch, 4 * preamble_length)\n",
    "        preamble_pair = tf.concat([self.preamble_original, self.preamble_convolved], 1)\n",
    "        preamble_pair_flat = tf.contrib.layers.flatten(preamble_pair)\n",
    "\n",
    "        cfo_layer_1 = tf.layers.dense(preamble_pair_flat, 100, tf.nn.tanh, use_bias=True)\n",
    "        cfo_layer_2 = tf.layers.dense(cfo_layer_1, 100, tf.nn.tanh, use_bias=True)\n",
    "        cfo_layer_3 = tf.layers.dense(cfo_layer_2, 100, activation=tf.nn.tanh, use_bias=True)\n",
    "        # linear output: one CFO estimate per example, shape (batch, 1)\n",
    "        self.est_omega = tf.layers.dense(cfo_layer_3, 1, activation=tf.identity, use_bias=True)\n",
    "\n",
    "        ###############\n",
    "        # CFO Correction on the convolved data\n",
    "        ###############\n",
    "        # phase[b, n] = -est_omega[b] * n, built by broadcasting instead of a\n",
    "        # Python loop that added one graph op per sample\n",
    "        sample_index = tf.range(self.data_length, dtype=tf.float32)\n",
    "        phase = -self.est_omega * sample_index\n",
    "        rotation_complex = tf.exp(tf.complex(tf.zeros_like(phase), phase))\n",
    "\n",
    "        data_cfo_complex = tf.complex(self.data_convolved[:,:,0], self.data_convolved[:,:,1])\n",
    "        data_rotated_complex = data_cfo_complex * rotation_complex\n",
    "        # back to (batch, data_length, 2)\n",
    "        data_cfo_corrected = tf.stack([tf.real(data_rotated_complex),\n",
    "                                       tf.imag(data_rotated_complex)], axis=2)\n",
    "\n",
    "        ###############\n",
    "        # Equalization on the cfo corrected data\n",
    "        ###############\n",
    "        est_data = []\n",
    "        for val in range(2):  # 0 = I, 1 = Q; each component gets its own MLP\n",
    "            data_channel_concat = tf.concat([data_cfo_corrected[:,:,val], self.channel[:,:,val]], 1)\n",
    "            layer_1 = tf.layers.dense(data_channel_concat, 400, tf.nn.sigmoid, use_bias=True)\n",
    "            layer_2 = tf.layers.dense(layer_1, 400, tf.nn.sigmoid, use_bias=True)\n",
    "            layer_3 = tf.layers.dense(layer_2, self.data_length, activation=tf.identity, use_bias=True)\n",
    "            est_data.append(layer_3)\n",
    "        # two (batch, data_length) tensors -> (batch, data_length, 2)\n",
    "        self.est_data = tf.stack(est_data, axis=2)\n",
    "\n",
    "        ###############################################\n",
    "        # Define surrogate loss and optimization tensor\n",
    "        ###############################################\n",
    "        self.surr = tf.losses.mean_squared_error(self.data_original, self.est_data)\n",
    "        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)\n",
    "        self.update_op = self.optimizer.minimize(self.surr)\n",
    "\n",
    "    def _feed(self, preamble_original, preamble_convolved, data_original, data_convolved, channel):\n",
    "        \"\"\"Map the five input arrays onto their placeholders.\"\"\"\n",
    "        return {\n",
    "            self.preamble_original: preamble_original,\n",
    "            self.preamble_convolved: preamble_convolved,\n",
    "            self.data_original: data_original,\n",
    "            self.data_convolved: data_convolved,\n",
    "            self.channel: channel,\n",
    "        }\n",
    "\n",
    "    def test_net(self, preamble_original, preamble_convolved, data_original, data_convolved, channel):\n",
    "        \"\"\"Evaluate the loss and the network outputs on one batch without updating weights.\n",
    "\n",
    "        Inputs (arrays matching the placeholders: batch first, last axis [I, Q]):\n",
    "            preamble_original, the preamble that was sent, before noise, channel and CFO\n",
    "            preamble_convolved, the received preamble after noise, channel and CFO\n",
    "            data_original, the data that was sent, before noise, channel and CFO\n",
    "            data_convolved, the received data after noise, channel and CFO\n",
    "            channel, the channel taps applied to the I and Q streams\n",
    "        Outputs:\n",
    "            test_cost, MSE between data_original and the estimated data\n",
    "            test_est, the estimated data, shape (batch, data_length, 2)\n",
    "            test_omega, the estimated CFO per example, shape (batch, 1)\n",
    "        \"\"\"\n",
    "        feed = self._feed(preamble_original, preamble_convolved, data_original, data_convolved, channel)\n",
    "        test_cost, test_est, test_omega = self.sess.run(\n",
    "            [self.surr, self.est_data, self.est_omega], feed_dict=feed)\n",
    "        return test_cost, test_est, test_omega\n",
    "\n",
    "    def train_net(self, preamble_original, preamble_convolved, data_original, data_convolved, channel):\n",
    "        \"\"\"Run one optimizer step on a batch.\n",
    "\n",
    "        Inputs: same arrays as test_net.\n",
    "        Outputs:\n",
    "            train_cost, MSE between data_original and the estimated data, evaluated\n",
    "            in the same session run as the update\n",
    "        \"\"\"\n",
    "        feed = self._feed(preamble_original, preamble_convolved, data_original, data_convolved, channel)\n",
    "        _, train_cost = self.sess.run([self.update_op, self.surr], feed_dict=feed)\n",
    "        return train_cost\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "np.random.seed()\n",
    "# tf.set_random_seed()\n",
    "\n",
    "'''\n",
    "CFO Equalizer is a net that, given a known preamble, will take in the convolved preamble and convolved data\n",
    "that have been passed through a multipath channel, and correct the CFO and equalize the data.\n",
    "For now we assume that the channel taps are known.\n",
    "'''\n",
    "\n",
    "class Equalizer_then_CFO():\n",
    "    \"\"\"Equalize the received preamble first, then estimate and remove its CFO.\n",
    "\n",
    "    TensorFlow 1.x graph, built in a private tf.Graph owned by the instance:\n",
    "      1. Equalization: for I and Q separately, a sigmoid MLP maps the received\n",
    "         preamble component concatenated with the matching channel taps to an\n",
    "         estimate of the preamble (est_data).\n",
    "      2. CFO estimation: a tanh MLP maps the concatenated original preamble and\n",
    "         est_data to one scalar per example, est_omega (radians per sample).\n",
    "      3. CFO correction: est_data sample n is multiplied by exp(-1j * est_omega * n).\n",
    "    The training loss is the MSE between the original preamble and the corrected\n",
    "    estimate; the data placeholders are fed but not used by the graph.\n",
    "    All I/Q tensors are (batch, length, 2) with the last axis holding [I, Q].\n",
    "    \"\"\"\n",
    "    def __init__(self, preamble_length = 50, data_length = 100, channel_length = 2,\n",
    "                 batch_size = 150, learning_rate = 0.01):\n",
    "        \"\"\"Build the graph, open a session on it and initialize its variables.\n",
    "\n",
    "        batch_size is only stored; the placeholders accept any batch size.\n",
    "        \"\"\"\n",
    "        # Network variables\n",
    "        self.preamble_length = preamble_length\n",
    "        self.data_length = data_length\n",
    "        self.channel_length = channel_length\n",
    "        self.batch_size = batch_size\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "        # A private graph keeps repeated constructions (e.g. re-running this cell)\n",
    "        # from piling models into the default graph, and makes the initializer\n",
    "        # below touch only this model's variables.\n",
    "        self.graph = tf.Graph()\n",
    "        with self.graph.as_default():\n",
    "            self._build_graph()\n",
    "            self.sess = tf.Session(graph=self.graph)\n",
    "            self.sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    def _build_graph(self):\n",
    "        \"\"\"Create the placeholders, the three network stages, the loss and the update op.\"\"\"\n",
    "        # Placeholders for training\n",
    "        self.preamble_original = tf.placeholder(tf.float32, [None, self.preamble_length, 2])\n",
    "        self.preamble_convolved = tf.placeholder(tf.float32, [None, self.preamble_length, 2])\n",
    "        self.data_original = tf.placeholder(tf.float32, [None, self.data_length, 2])\n",
    "        self.data_convolved = tf.placeholder(tf.float32, [None, self.data_length, 2])\n",
    "        self.channel = tf.placeholder(tf.float32, [None, self.channel_length, 2])\n",
    "\n",
    "        ###############\n",
    "        # Equalization on the received (convolved) preamble\n",
    "        ###############\n",
    "        est_data = []\n",
    "        for val in range(2):  # 0 = I, 1 = Q; each component gets its own MLP\n",
    "            preamble_channel_concat = tf.concat([self.preamble_convolved[:,:,val], self.channel[:,:,val]], 1)\n",
    "            layer_1 = tf.layers.dense(preamble_channel_concat, 400, tf.nn.sigmoid, use_bias=True)\n",
    "            layer_2 = tf.layers.dense(layer_1, 400, tf.nn.sigmoid, use_bias=True)\n",
    "            layer_3 = tf.layers.dense(layer_2, self.preamble_length, activation=tf.identity, use_bias=True)\n",
    "            est_data.append(layer_3)\n",
    "        # two (batch, preamble_length) tensors -> (batch, preamble_length, 2)\n",
    "        self.est_data = tf.stack(est_data, axis=2)\n",
    "\n",
    "        ###############\n",
    "        # CFO Estimation using the original and the equalized preamble\n",
    "        ###############\n",
    "        # (batch, 2 * preamble_length, 2) -> (batch, 4 * preamble_length)\n",
    "        preamble_pair = tf.concat([self.preamble_original, self.est_data], 1)\n",
    "        preamble_pair_flat = tf.contrib.layers.flatten(preamble_pair)\n",
    "\n",
    "        cfo_layer_1 = tf.layers.dense(preamble_pair_flat, 100, tf.nn.tanh, use_bias=True)\n",
    "        cfo_layer_2 = tf.layers.dense(cfo_layer_1, 100, tf.nn.tanh, use_bias=True)\n",
    "        cfo_layer_3 = tf.layers.dense(cfo_layer_2, 100, activation=tf.nn.tanh, use_bias=True)\n",
    "        # linear output: one CFO estimate per example, shape (batch, 1)\n",
    "        self.est_omega = tf.layers.dense(cfo_layer_3, 1, activation=tf.identity, use_bias=True)\n",
    "\n",
    "        ###############\n",
    "        # CFO Correction on the equalized preamble\n",
    "        ###############\n",
    "        # phase[b, n] = -est_omega[b] * n, built by broadcasting instead of a\n",
    "        # Python loop that added one graph op per sample\n",
    "        sample_index = tf.range(self.preamble_length, dtype=tf.float32)\n",
    "        phase = -self.est_omega * sample_index\n",
    "        rotation_complex = tf.exp(tf.complex(tf.zeros_like(phase), phase))\n",
    "\n",
    "        est_data_complex = tf.complex(self.est_data[:,:,0], self.est_data[:,:,1])\n",
    "        rotated_complex = est_data_complex * rotation_complex\n",
    "        # back to (batch, preamble_length, 2)\n",
    "        self.data_cfo_corrected = tf.stack([tf.real(rotated_complex),\n",
    "                                            tf.imag(rotated_complex)], axis=2)\n",
    "\n",
    "        ###############################################\n",
    "        # Define surrogate loss and optimization tensor\n",
    "        ###############################################\n",
    "        self.surr = tf.losses.mean_squared_error(self.preamble_original, self.data_cfo_corrected)\n",
    "        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)\n",
    "        self.update_op = self.optimizer.minimize(self.surr)\n",
    "\n",
    "    def _feed(self, preamble_original, preamble_convolved, data_original, data_convolved, channel):\n",
    "        \"\"\"Map the five input arrays onto their placeholders.\"\"\"\n",
    "        return {\n",
    "            self.preamble_original: preamble_original,\n",
    "            self.preamble_convolved: preamble_convolved,\n",
    "            self.data_original: data_original,\n",
    "            self.data_convolved: data_convolved,\n",
    "            self.channel: channel,\n",
    "        }\n",
    "\n",
    "    def test_net(self, preamble_original, preamble_convolved, data_original, data_convolved, channel):\n",
    "        \"\"\"Evaluate the loss and the network outputs on one batch without updating weights.\n",
    "\n",
    "        Inputs (arrays matching the placeholders: batch first, last axis [I, Q]):\n",
    "            preamble_original, the preamble that was sent, before noise, channel and CFO\n",
    "            preamble_convolved, the received preamble after noise, channel and CFO\n",
    "            data_original, the data that was sent (fed but unused by this graph)\n",
    "            data_convolved, the received data (fed but unused by this graph)\n",
    "            channel, the channel taps applied to the I and Q streams\n",
    "        Outputs:\n",
    "            test_cost, MSE between preamble_original and the CFO-corrected estimate\n",
    "            test_est, the equalized preamble before CFO correction, (batch, preamble_length, 2)\n",
    "            test_omega, the estimated CFO per example, shape (batch, 1)\n",
    "            test_corr, the equalized and CFO-corrected preamble, (batch, preamble_length, 2)\n",
    "        \"\"\"\n",
    "        feed = self._feed(preamble_original, preamble_convolved, data_original, data_convolved, channel)\n",
    "        test_cost, test_est, test_omega, test_corr = self.sess.run(\n",
    "            [self.surr, self.est_data, self.est_omega, self.data_cfo_corrected], feed_dict=feed)\n",
    "        return test_cost, test_est, test_omega, test_corr\n",
    "\n",
    "    def train_net(self, preamble_original, preamble_convolved, data_original, data_convolved, channel):\n",
    "        \"\"\"Run one optimizer step on a batch.\n",
    "\n",
    "        Inputs: same arrays as test_net.\n",
    "        Outputs:\n",
    "            train_cost, MSE between preamble_original and the CFO-corrected estimate,\n",
    "            evaluated in the same session run as the update\n",
    "        \"\"\"\n",
    "        feed = self._feed(preamble_original, preamble_convolved, data_original, data_convolved, channel)\n",
    "        _, train_cost = self.sess.run([self.update_op, self.surr], feed_dict=feed)\n",
    "        return train_cost\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, 2, 2]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: static shape of a (batch, 2, 2) float placeholder\n",
    "inputs = tf.placeholder(dtype=tf.float32, shape=(None, 2, 2))\n",
    "inputs.shape.as_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def _qpsk(shape, rng):\n",
    "    \"\"\"Random QPSK symbols in I/Q form: every component is +sqrt(2)/2 or -sqrt(2)/2.\"\"\"\n",
    "    return ((2 * rng.randint(2, size=shape)) - 1) * np.sqrt(2) / 2\n",
    "\n",
    "\n",
    "def _apply_channel(signal, channel):\n",
    "    \"\"\"Convolve every example's I and Q streams with that example's channel taps.\n",
    "\n",
    "    Vectorized equivalent of scipy.signal.convolve(signal[i, :, k], channel[i, :, k],\n",
    "    mode='same') for every example i and component k: the output keeps the input\n",
    "    length and is the full convolution starting at index (taps - 1) // 2.\n",
    "\n",
    "    Inputs:\n",
    "        signal, array of shape (num, length, 2)\n",
    "        channel, array of shape (num, taps, 2)\n",
    "    Outputs:\n",
    "        array of shape (num, length, 2)\n",
    "    \"\"\"\n",
    "    length = signal.shape[1]\n",
    "    taps = channel.shape[1]\n",
    "    start = (taps - 1) // 2\n",
    "    out = np.zeros(signal.shape)\n",
    "    for m in range(taps):\n",
    "        # out[:, n] += channel[:, m] * signal[:, n + shift] wherever n + shift is a valid index\n",
    "        shift = start - m\n",
    "        tap = channel[:, m:m + 1, :]\n",
    "        if 0 <= shift < length:\n",
    "            out[:, :length - shift, :] += tap * signal[:, shift:, :]\n",
    "        elif 0 < -shift < length:\n",
    "            out[:, -shift:, :] += tap * signal[:, :length + shift, :]\n",
    "    return out\n",
    "\n",
    "\n",
    "def _apply_cfo(signal, omega):\n",
    "    \"\"\"Rotate sample n of example i by omega[i] * n radians (multiply by exp(1j * omega * n)).\n",
    "\n",
    "    Inputs:\n",
    "        signal, array of shape (num, length, 2), last axis [I, Q]\n",
    "        omega, array of shape (num, 1), CFO in radians per sample\n",
    "    Outputs:\n",
    "        array of shape (num, length, 2)\n",
    "    \"\"\"\n",
    "    phase = omega * np.arange(signal.shape[1])  # (num, length)\n",
    "    cos_phase, sin_phase = np.cos(phase), np.sin(phase)\n",
    "    in_phase, quadrature = signal[:, :, 0], signal[:, :, 1]\n",
    "    return np.stack([in_phase * cos_phase - quadrature * sin_phase,\n",
    "                     in_phase * sin_phase + quadrature * cos_phase], axis=2)\n",
    "\n",
    "\n",
    "def _impair(clean, channel, omega, snr, rng):\n",
    "    \"\"\"Pass clean I/Q symbols through noise, then the channel, then the CFO rotation.\n",
    "\n",
    "    snr > 0 adds Gaussian noise with standard deviation 1/sqrt(snr) per component;\n",
    "    snr <= 0 adds no noise.\n",
    "    \"\"\"\n",
    "    if snr > 0:\n",
    "        noisy = clean + (1. / np.sqrt(snr)) * rng.randn(*clean.shape)\n",
    "    else:\n",
    "        noisy = clean\n",
    "    return _apply_cfo(_apply_channel(noisy, channel), omega)\n",
    "\n",
    "\n",
    "def _make_split(num, preamble_length, data_length, channel_length, max_omega, snr, rng, chunk_size):\n",
    "    \"\"\"Generate channels, sent/received preambles and data, and CFOs for num examples.\"\"\"\n",
    "    # the cfo rate is in number of radians turned per sample\n",
    "    omega = rng.uniform(low=0, high=max_omega, size=(num, 1))\n",
    "\n",
    "    # Known channel with the same taps on I and Q: a unit impulse (no multipath),\n",
    "    # normalized to unit power so other tap choices stay comparable.\n",
    "    channel = np.zeros((num, channel_length, 2))\n",
    "    channel[:, 0, :] = 1.0\n",
    "    channel /= np.linalg.norm(channel, axis=1, keepdims=True)\n",
    "\n",
    "    # original preamble (known to TX and RX) and original data (known to TX)\n",
    "    preamble_orig = _qpsk((num, preamble_length, 2), rng)\n",
    "    data_orig = _qpsk((num, data_length, 2), rng)\n",
    "\n",
    "    # Impair in chunks so the vectorized temporaries stay small for millions of examples.\n",
    "    preamble_cfo = np.empty_like(preamble_orig)\n",
    "    data_cfo = np.empty_like(data_orig)\n",
    "    for start in range(0, num, chunk_size):\n",
    "        rows = slice(start, start + chunk_size)\n",
    "        preamble_cfo[rows] = _impair(preamble_orig[rows], channel[rows], omega[rows], snr, rng)\n",
    "        # the data goes through the same channel and CFO as its preamble\n",
    "        data_cfo[rows] = _impair(data_orig[rows], channel[rows], omega[rows], snr, rng)\n",
    "    return channel, preamble_orig, preamble_cfo, data_orig, data_cfo, omega\n",
    "\n",
    "\n",
    "def data_generation(epochs, batch_size=1500, snr=50, preamble_length=50, data_length=100,\n",
    "                    channel_length=2, max_omega=1 / 50, seed=None, chunk_size=None):\n",
    "    \"\"\"Create QPSK training and test sets for the CFO / equalizer networks.\n",
    "\n",
    "    Each example is a random QPSK preamble and data block that go through additive\n",
    "    Gaussian noise, the channel (applied separately to I and Q as a 'same'-length\n",
    "    convolution) and a CFO rotation exp(1j * omega * n), with omega drawn uniformly\n",
    "    from [0, max_omega). All I/Q arrays are (examples, length, 2), last axis [I, Q].\n",
    "\n",
    "    Inputs:\n",
    "        epochs, number of training batches; the training set has epochs * batch_size examples\n",
    "        batch_size, examples per batch; also the size of the test set\n",
    "        snr, linear SNR; noise std per component is 1/sqrt(snr), snr <= 0 disables noise\n",
    "        preamble_length, data_length, channel_length, sequence lengths\n",
    "        max_omega, upper bound of the CFO in radians per sample\n",
    "        seed, optional seed for a private RandomState; None uses the global np.random\n",
    "        chunk_size, examples impaired per vectorized step (defaults to batch_size)\n",
    "    Outputs:\n",
    "        (channel_train, channel_test, preamble_train_orig, preamble_train_cfo, preamble_test_orig,\n",
    "         preamble_test_cfo, data_train_orig, data_train_cfo, data_test_orig, data_test_cfo, omega_test)\n",
    "    \"\"\"\n",
    "    rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "    chunk_size = max(1, chunk_size or batch_size)\n",
    "    split_args = (preamble_length, data_length, channel_length, max_omega, snr, rng, chunk_size)\n",
    "\n",
    "    # number of random preambles that will be trained and tested on\n",
    "    num_train = epochs * batch_size\n",
    "    num_test = batch_size\n",
    "\n",
    "    (channel_train, preamble_train_orig, preamble_train_cfo,\n",
    "     data_train_orig, data_train_cfo, _) = _make_split(num_train, *split_args)\n",
    "    (channel_test, preamble_test_orig, preamble_test_cfo,\n",
    "     data_test_orig, data_test_cfo, omega_test) = _make_split(num_test, *split_args)\n",
    "\n",
    "    print(\"Data generation complete.\")\n",
    "\n",
    "    return (channel_train, channel_test, preamble_train_orig, preamble_train_cfo, preamble_test_orig,\n",
    "            preamble_test_cfo, data_train_orig, data_train_cfo, data_test_orig, data_test_cfo, omega_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "100000\n",
      "200000\n",
      "300000\n",
      "400000\n",
      "500000\n",
      "600000\n",
      "700000\n",
      "800000\n",
      "900000\n",
      "1000000\n",
      "1100000\n",
      "1200000\n",
      "1300000\n",
      "1400000\n",
      "1500000\n",
      "1600000\n",
      "1700000\n",
      "1800000\n",
      "1900000\n",
      "2000000\n",
      "2100000\n",
      "2200000\n",
      "2300000\n",
      "2400000\n",
      "2500000\n",
      "2600000\n",
      "2700000\n",
      "2800000\n",
      "2900000\n",
      "Data generation complete.\n"
     ]
    }
   ],
   "source": [
    "epochs = 2000  # data_generation builds epochs * batch_size training examples\n",
    "\n",
    "# generate the training and test sets\n",
    "(channel_train, channel_test,\n",
    " preamble_train_orig, preamble_train_cfo, preamble_test_orig, preamble_test_cfo,\n",
    " data_train_orig, data_train_cfo, data_test_orig, data_test_cfo,\n",
    " omega_test) = data_generation(epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Train Cost 0.996038556098938, Test Cost: 2.978325128555298\n",
      "Epoch 100, Train Cost 0.5000013113021851, Test Cost: 0.5000090003013611\n",
      "Epoch 200, Train Cost 0.5000097751617432, Test Cost: 0.5000061988830566\n",
      "Epoch 300, Train Cost 0.5000213384628296, Test Cost: 0.500011682510376\n",
      "Epoch 400, Train Cost 0.5000109672546387, Test Cost: 0.4999953508377075\n",
      "Epoch 500, Train Cost 0.5000241994857788, Test Cost: 0.500044047832489\n",
      "Epoch 600, Train Cost 0.5000131726264954, Test Cost: 0.5000127553939819\n",
      "Epoch 700, Train Cost 0.5000501275062561, Test Cost: 0.5000619888305664\n",
      "Epoch 800, Train Cost 0.500013530254364, Test Cost: 0.5000439286231995\n",
      "Epoch 900, Train Cost 0.500022292137146, Test Cost: 0.5000061988830566\n",
      "Epoch 1000, Train Cost 0.5000419616699219, Test Cost: 0.5000361800193787\n",
      "Epoch 1100, Train Cost 0.50003981590271, Test Cost: 0.500011682510376\n",
      "Epoch 1200, Train Cost 0.5000222325325012, Test Cost: 0.5000080466270447\n",
      "Epoch 1300, Train Cost 0.5000398755073547, Test Cost: 0.5000371336936951\n",
      "Epoch 1400, Train Cost 0.5000636577606201, Test Cost: 0.5000303387641907\n",
      "Epoch 1500, Train Cost 0.5000579953193665, Test Cost: 0.5000369548797607\n",
      "Epoch 1600, Train Cost 0.5000678896903992, Test Cost: 0.5000512003898621\n",
      "Epoch 1700, Train Cost 0.5000233054161072, Test Cost: 0.5000203251838684\n",
      "Epoch 1800, Train Cost 0.500093936920166, Test Cost: 0.5000514984130859\n",
      "Epoch 1900, Train Cost 0.5000334978103638, Test Cost: 0.5000484585762024\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAEWCAYAAABBvWFzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAH6VJREFUeJzt3Xu8VWW97/HPVxDvF5Blm5SbiZXdUGdkl91l57VSPJWK0g7LDtVOyzp20m1lWZ5DtU92Skup8JKkmOWOyo6SpWZby4UhgoYg3gjEVXjHVOB3/hjPgsFkzrnmWmvMOdaS7/v1Gq85xjOeZ4zfHBPmb43LfB5FBGZmZv21TdkBmJnZi4MTipmZFcIJxczMCuGEYmZmhXBCMTOzQjihmJlZIZxQzAoiKSTtm+YvlPSFsmMqg6Qxkp6WNKTsWKy9nFCsNJIekLRa0k65so9IujG3HJLukrRNruyrki6ps823S9qQvtDy0xtb+V6qRcTHIuIr7dznQBERD0XEzhGxvuxYrL2cUKxsQ4FP9VDnpcCUXmxzZfpCy0+39j3EgcF/8dtA54RiZfsGcLqk3RvU+TrwZUlD+7szSeMl3STpKUnzJJ0v6fK07u2SVlTVf0DSIWl+kqRbJT0uaVVqO6zOfi6R9NU0/4uqs6UNkk5K616R4lgjaYmk46q28T1J10p6BnhHE+/vRklfkfSH9B6vlzQyt/5oSYvTe7hR0isbbCskfUzSUkmPSbpAktK6bSR9XtKDkh6VdJmk3dK6cant0LR8kqTlKZ77JU3N7ePDku5J279O0tie3qMNXE4oVrZO4Ebg9AZ1fgY8CZxUwP5+DMwHRgJfAab1ou164NOp7RuBdwL/1lOjiDiq+0wJeD/wCHBDutQ3L8W0J3AC8F1Jr8o1PxE4F9gFuEXSiZIW9rDLE4EPpW0OIx1bSfsBVwCnAR3AtcAv6iXF5D3A64HXAccBh6fyk9L0DmAfYGfg/OrG6T1+GzgyInYB3gQsSOuOAf4deG+K5/cpPhuknFBsIPgicKqkjjrrA/gC8EVJ2zWxvZemv8Dz006SxpB9OX4hIp6LiJuBXzQbZETMj4jbImJdRDwAXAS8rdn26Qv9MuD4iHiY7Mv6gYi4OG3zDuCnZEmn288j4g8RsSEi/hERP46I1/awq4sj4t6IeBa4CpiYyo8HfhUR8yLiBeA/gB3IvuTrmRERj0fEQ8DvctuaCnwzIpZHxNPAmcCUOmeRG4BXS9ohIlZFxOJU/lHgf0fEPRGxDvhfwESfpQxeTihWuohYBPwSOKNBnWuBh4DpTWxyZUTsXjU9Q3Yv5rE03+3BZuOUtJ+kX0p6RNKTZF+AI3tql9ruBvycLJn9PhWPBd6QT3xkX9T/lGv6cLPx5TySm19LdvYA2fvf+H4jYkPa/l793VaaHwq8JN84HevjgY8BqyT9StIr0uqxwP/Nvfc1gHqIxwYwJxQbKM4G/juNv0w+D5wF7NjHfawChuefKgPG5OafyW873QTPnzV9D/gLMCEidiW7XKOedpqeUPsx8LuIuCi36mHgpqrEt3NEfDxXp8juwFeSfYl3xyVgNPDX/m6L7DiuA1ZXV4yI6yLiUGAU2fH7flr1MPDRqve/Q0T8Vx/isQHACcUGhIhYBswBPtmgzo3AXfTuvke+/YNk92y+LGmYpLcAR+Wq3AtsL+ndkrYlS2D5S2y7kN3LeTr9lZ3/4m/kXGAntnya7ZfAfpL+VdK2aXp9oxvl/XQV8G5J70zv738AzwF9+QK/Avh0eshhZ7KztTnp0tVGkl6SHgTYKe3rabJ7UQAXAmd23zOStJukY/v0zmxAcEKxgeQcsi/eRj4PjOihzku15e9Q3pfWnQi8gezyytlk9zQAiIgnyG6y/4Dsr/ZngPxTX6en9k+R/ZU9p6l3ld1sPxh4LBfP1Ih4CjiM7JHolWSXl77G5klsM5KmSlpc
b30jEbEE+ADwHeBvZMn0qIh4vg+bmwX8CLgZuB/4B3BqjXrbkCWulWTH/G2kBxki4hqy93tluoS4CDiyD7HYACEPsGVbM0lfAvaNiA+UHYvZYOczFDMzK4QTipmZFcKXvMzMrBA+QzEzs0L0u2+kwWTkyJExbty4ssMwMxtU5s+f/7eIqNeTxUZbVUIZN24cnZ2dZYdhZjaoSGqqRwlf8jIzs0I4oZiZWSGcUMzMrBBOKGZmVggnFDMzK4QTSg9mz4Zx42CbbbLX2bPLjsjMbGDaqh4b7q3Zs2H6dFi7Nlt+8MFsGWDq1PrtzMy2Rj5DaeCsszYlk25r12blZma2OSeUBh56qHflZmZbMyeUBsaM6V25mdnWzAmlgXPPhR2rRi/fcces3MzMNueE0sDUqTBzJowdC1L2OnOmb8ibmdXip7x6MHWqE4iZWTN8hmJmZoVwQjEzs0KUmlAkzZL0qKRFddZL0rclLZO0UNKBuXXTJC1N07T2RW1mZrWUfYZyCXBEg/VHAhPSNB34HoCkEcDZwBuAScDZkoa3NFIzM2uo1IQSETcDaxpUmQxcFpnbgN0ljQIOB+ZFxJqIeAyYR+PEZGZmLVb2GUpP9gIezi2vSGX1yrcgabqkTkmdXV1dLQvUzGxrN9ATimqURYPyLQsjZkZEJSIqHR0dhQZnZmabDPSEsgIYnVveG1jZoNzMzEoy0BPKXOCD6Wmvg4EnImIVcB1wmKTh6Wb8YanMzMxKUuov5SVdAbwdGClpBdmTW9sCRMSFwLXAu4BlwFrgQ2ndGklfAW5PmzonIhrd3DczsxYrNaFExAk9rA/gE3XWzQJmtSIuMzPrvYF+ycvMzAYJJxQzMyuEE4qZmRXCCcXMzArhhGJmZoVwQjEzs0I4oZiZWSGcUMzMrBBOKGZmVggnFDMzK4QTipmZFcIJxczMCuGEYmZmhXBCMTOzQjihmJlZIZxQzMysEE4oZmZWCCcUMzMrhBOKmZkVotSEIukISUskLZN0Ro3150lakKZ7JT2eW7c+t25ueyM3M7NqQ8vasaQhwAXAocAK4HZJcyPi7u46EfHpXP1TgQNym3g2Iia2K14zM2uszDOUScCyiFgeEc8DVwKTG9Q/AbiiLZGZmVmvlZlQ9gIezi2vSGVbkDQWGA/8Nle8vaROSbdJOqbeTiRNT/U6u7q6iojbzMxqKDOhqEZZ1Kk7Bbg6ItbnysZERAU4EfiWpJfVahgRMyOiEhGVjo6O/kVsZmZ1lZlQVgCjc8t7Ayvr1J1C1eWuiFiZXpcDN7L5/RUzM2uzMhPK7cAESeMlDSNLGls8rSXp5cBw4NZc2XBJ26X5kcCbgbur25qZWfuU9pRXRKyTdApwHTAEmBURiyWdA3RGRHdyOQG4MiLyl8NeCVwkaQNZUpyRfzrMzMzaT5t/T7+4VSqV6OzsLDsMM7NBRdL8dM+6If9S3szMCuGEYmZmhXBCMTOzQjihmJlZIZxQzMysEE4oZmZWCCcUMzMrhBOKmZkVwgnFzMwK4YRiZmaFcEIxM7NCOKGYmVkhnFDMzKwQTihmZlYIJxQzMyuEE4qZmRXCCcXMzArhhGJmZoVwQjEzs0KUmlAkHSFpiaRlks6osf4kSV2SFqTpI7l10yQtTdO09kZuZmbVhpa1Y0lDgAuAQ4EVwO2S5kbE3VVV50TEKVVtRwBnAxUggPmp7WNtCN3MzGoo8wxlErAsIpZHxPPAlcDkJtseDsyLiDUpicwDjmhRnGZm1oQyE8pewMO55RWprNr7JC2UdLWk0b1si6TpkjoldXZ1dRURt5mZ1VBmQlGNsqha/gUwLiJeC/wGuLQXbbPCiJkRUYmISkdHR5+DNTOzxspMKCuA0bnlvYGV+QoR8feIeC4tfh84qNm2ZmbWXmUmlNuBCZLGSxoGTAHm5itIGpVbPBq4J81fBxwmabik4cBhqczMzErS41NeksZExEM9lfVWRKyTdApZIhgCzIqIxZLOATojYi7w
SUlHA+uANcBJqe0aSV8hS0oA50TEmv7EY2Zm/aOImrceNlWQ7oiIA3sqGwwqlUp0dnaWHYaZ2aAiaX5EVHqqV/cMRdJ+wCuB3dJZQrddge37H6KZmb2YNLrk9SrgvcDuwLG58qeAj7YyKDMzG3zqJpSIuAa4RtJbIuKWNsZkZmaDUDNPeb1b0q6Shkq6TtJqSSe2PDIzMxtUmkkoR0bEk8B7gEeBVwOfa2lUZmY26DSTULZNr+8CroiILur8Kt3MzLZezfQ2fK2kRcB64BOSRgLP9dDGzMy2Mj2eoUTEZ4F/AQ6KiBeAZ8me/jIzM9uomV/KDwXeD7xVEsBNZP1qmZmZbdTMJa8LgJ2AWWn5A8ABwPRWBWVmZoNPMwnl4Ih4XW75ekl3tiogMzMbnJp5ymuDpHHdC2l+Q2vCMTOzwaqZM5T/Cdws6V6yga32BU5uaVRmZjbo9JhQImKepJeTdRQp4O6IeLblkZmZ2aDSqLfhE4AhEXF5SiB3pPKPSHoqIua0K0gzMxv4Gt1D+SxVIygmPyG7DGZmZrZRo4QyNPXhtZmIeIJN3bGYmZkBjRPKMEk7VhdK2hnYrnUhmZnZYNQoocwCfiJp7+6CNP9j4OIidi7pCElLJC2TdEaN9Z+RdLekhZJukDQ2t269pAVpqnVpzszM2qjRAFtfl7QW+GPqfgXgBWBGRJzf3x1LGkL2K/xDgRXA7ZLmRsTduWp/BioRsVbSx4GvA8endc9GxMT+xmFmZsVo+NhwShznS9odUEQ8VuC+JwHLImI5gKQrgcnAxoQSEb/L1b+NrNsXMzMbgJr5pTwR8XjByQRgL+Dh3PKKVFbPycCvc8vbS+qUdJukY+o1kjQ91evs6urqX8RmZlZXM7+UbxXVKKs5cJekDwAV4G254jERsVLSPsBvJd0VEfdtscGImcBMgEql4oHBzMxapMczlNz9k4ZlfbACGJ1b3htYWWNfhwBnAUdHxMaBvSJiZXpdDtxI1gOymZmVpJlLXn9qsqy3bgcmSBovaRgwhaofUko6ALiILJk8misfLmm7ND8SeDO5ey9mZtZ+jbpe2RMYBewg6TVsukS1K7DF71N6KyLWSToFuA4YAsyKiMWSzgE6I2Iu8A1gZ7LHlwEeioijyfoVu0jSBrKkOKPq6TAzM2uzRpeu3g18mOxS1AVsSihPAV8oYucRcS1wbVXZF3Pzh9Rp91/Aa4qIwczMitHodygXAxdLOi4irmpjTGZmNgg1cw9lT0m7Aki6UNKfJL2zxXGZmdkg00xCmR4RT0o6jOzyV/cv1s3MzDZqJqF0/3bjSODiiJjfZDszM9uKNJMY7pR0LXAU8OvU27B/IGhmZptp5geKHwIOIut3a2363YfHlDczs830eIYSEeuBfcjunQDs0Ew7MzPbujTT9cr5wDvY1NPvM8CFrQzKzMwGn2Yueb0pIg6U9GeAiFiTukoxMzPbqJlLVy9I2oZ0I17SHsCGlkZlZmaDTjMJ5QLgp0CHpC8DtwBfa2lUZmY26DTqHHJoRKyLiMskzQcOIevP69iIWNS2CM3MbFBodA/lT8CBABGxGFjclojMzGxQanTJq9aIimZmZjU1OkPpkPSZeisj4pstiMfMzAapRgllCNngVj5TMTOzHjVKKKsi4py2RWJmZoOa76GYmVkhGiUUD6JlZmZNq5tQImJNq3cu6QhJSyQtk3RGjfXbSZqT1v9R0rjcujNT+RJJh7c6VjMza6y0XoMlDSH7Ff6RwP7ACZL2r6p2MvBYROwLnEf6hX6qNwV4FXAE8N20PTMzK0mZ3dBPIhtjZXlEPA9cCUyuqjMZuDTNXw28U5JS+ZUR8VxE3A8sS9szM7OS1E0okkZLulLS7yX9u6Rtc+v+s4B97wU8nFtekcpq1omIdcATwB5Ntu2OdbqkTkmdXV1dBYRtZma1NDpDmQXcCJwKjAJuSj0NA4wtYN+1niKrHlq4Xp1m2maFETMjohIRlY6Ojl6GaGZmzWqU
UDoi4sKIWBARpwLfBW6W9DKKGVN+BTA6t7w3sLJeHUlDgd2ANU22NTOzNmqUULaVtH33QkRcDnwKuI7sjKW/bgcmSBqfBuyaAsytqjMXmJbm3w/8NiIilU9JT4GNByaQdWZpZmYlaZRQfgC8IV8QEb8BjgX63X19uidyClmCuge4KiIWSzpH0tGp2g+BPSQtAz4DnJHaLgauAu4G/h/wiYhY39+YzMys75T9wd/LRtJpEfGtFsTTUpVKJTo7O8sOw8xsUJE0PyIqPdXr62PDdXshNjOzrVNfE4r7+TIzs830NaEU8ZSXmZm9iDQaU/4paicOATu0LCIzMxuU6iaUiNilnYGYmdngVmZfXmZm9iLihGJmZoVwQjEzs0I4oZiZWSGcUMzMrBBOKGZmVggnFDMzK4QTipmZFcIJxczMCuGEYmZmhXBCMTOzQjihmJlZIZxQzMysEE4oZmZWiFISiqQRkuZJWppeh9eoM1HSrZIWS1oo6fjcuksk3S9pQZomtvcdmJlZtbLOUM4AboiICcANabnaWuCDEfEq4AjgW5J2z63/bERMTNOC1odsZmaNlJVQJgOXpvlLgWOqK0TEvRGxNM2vBB4FOtoWoZmZ9UpZCeUlEbEKIL3u2aiypEnAMOC+XPG56VLYeZK2a9B2uqROSZ1dXV1FxG5mZjW0LKFI+o2kRTWmyb3czijgR8CHImJDKj4TeAXwemAE8Ll67SNiZkRUIqLS0eETHDOzVqk7pnx/RcQh9dZJWi1pVESsSgnj0Tr1dgV+BXw+Im7LbXtVmn1O0sXA6QWGbmZmfVDWJa+5wLQ0Pw34eXUFScOAa4DLIuInVetGpVeR3X9Z1NJozcysR2UllBnAoZKWAoemZSRVJP0g1TkOeCtwUo3Hg2dLugu4CxgJfLW94ZuZWTVFRNkxtE2lUonOzs6ywzAzG1QkzY+ISk/1/Et5MzMrhBOKmZkVwgnFzMwK4YRiZmaFcEIxM7NCOKGYmVkhnFDMzKwQTihmZlYIJxQzMyuEE4qZmRXCCcXMzArhhGJmZoVwQjEzs0I4oZiZWSGcUMzMrBBOKGZmVggnFDMzK4QTipmZFaKUhCJphKR5kpam1+F16q3PjSc/N1c+XtIfU/s5koa1L3ozM6ulrDOUM4AbImICcENaruXZiJiYpqNz5V8DzkvtHwNObm24ZmbWk7ISymTg0jR/KXBMsw0lCfgX4Oq+tDczs9YoK6G8JCJWAaTXPevU215Sp6TbJHUnjT2AxyNiXVpeAezV2nDNzKwnQ1u1YUm/Af6pxqqzerGZMRGxUtI+wG8l3QU8WaNeNIhjOjAdYMyYMb3YtZmZ9UbLEkpEHFJvnaTVkkZFxCpJo4BH62xjZXpdLulG4ADgp8Dukoams5S9gZUN4pgJzASoVCp1E4+ZmfVPWZe85gLT0vw04OfVFSQNl7Rdmh8JvBm4OyIC+B3w/kbtzcysvcpKKDOAQyUtBQ5Ny0iqSPpBqvNKoFPSnWQJZEZE3J3WfQ74jKRlZPdUftjW6M3MbAvK/uDfOlQqlejs7Cw7DDOzQUXS/Iio9FTPv5Q3M7NCOKGYmVkhnFDMzKwQTihmZlYIJxQzMyuEE4qZmRXCCcXMzArhhGJmZoVwQjEzs0I4oZiZWSGcUMzMrBBOKGZmVggnFDMzK4QTSovNng3jxsE222Svs2e7vdu7vdsPjva9FhFbzXTQQQdFO11+ecSOO0bApmnHHbNyt3d7t3f7gdw+D+iMJr5jS/+Sb+fU7oQyduzmH2b3NHas27u927v9wG6f12xC8QBbLbTNNtlHWE2CDRvc3u3d3u0HbvvN23iArdKNGdO7crd3e7d3+4HSvk+aOY15sUy+h+L2bu/2bt9c+zwG8j0UYAQwD1iaXofXqPMOYEFu+gdwTFp3CXB/bt3EZvbb7oQSkX14Y8dGSNlrbz9Mt3d7t3f7stp3azahlHIPRdLXgTURMUPSGSmhfK5B/RHAMmDviFgr6RLg
lxFxdW/22+57KGZmLwYD/R7KZODSNH8pcEwP9d8P/Doi1rY0KjMz67OyEspLImIVQHrds4f6U4ArqsrOlbRQ0nmStqvXUNJ0SZ2SOru6uvoXtZmZ1dWyhCLpN5IW1Zgm93I7o4DXANflis8EXgG8nux+TN3LZRExMyIqEVHp6OjowzsxM7NmDG3VhiPikHrrJK2WNCoiVqWE8WiDTR0HXBMRL+S2vSrNPifpYuD0QoI2M7M+K+uS11xgWpqfBvy8Qd0TqLrclZIQkkR2/2VRC2I0M7NeKOsprz2Aq4AxwEPAsRGxRlIF+FhEfCTVGwf8ARgdERty7X8LdAAie2z4YxHxdBP77QIe7GPYI4G/9bFtOzi+/nF8/eP4+megxzc2Inq8Z7BVdb3SH5I6m3lsriyOr38cX/84vv4Z6PE1y12vmJlZIZxQzMysEE4ozZtZdgA9cHz94/j6x/H1z0CPrym+h2JmZoXwGYqZmRXCCcXMzArhhNIESUdIWiJpWeodud37Hy3pd5LukbRY0qdS+Zck/VXSgjS9K9fmzBTvEkmHtynOByTdlWLpTGUjJM2TtDS9Dk/lkvTtFONCSQe2MK6X547RAklPSjqt7OMnaZakRyUtypX1+nhJmpbqL5U0rda+CozvG5L+kmK4RtLuqXycpGdzx/LCXJuD0r+LZek9qIXx9fozbdX/7zrxzcnF9oCkBam87cevJZrp435rnoAhwH3APsAw4E5g/zbHMAo4MM3vAtwL7A98CTi9Rv39U5zbAeNT/EPaEOcDwMiqsq8DZ6T5M4Cvpfl3Ab8m+3HqwcAf2/h5PgKMLfv4AW8FDgQW9fV4kfVltzy9Dk/zW4wvVGB8hwFD0/zXcvGNy9er2s6fgDem2H8NHNnC+Hr1mbby/3et+KrW/x/gi2Udv1ZMPkPp2SRgWUQsj4jngSvJut9vm4hYFRF3pPmngHuAvRo0mQxcGRHPRcT9ZGPJTGp9pHVjqTVUwWTgssjcBuyu1KVOi70TuC8iGvWY0JbjFxE3A2tq7Ls3x+twYF5ErImIx8gGrDuiVfFFxPURsS4t3gbs3WgbKcZdI+LWyL4dL6Pn4Sr6HF8D9T7Tlv3/bhRfOss4ji17Ua+u17Lj1wpOKD3bC3g4t7yCxl/mLaWsO5oDgD+molPS5YdZ3ZdHKC/mAK6XNF/S9FRWb6iCsmKsHgphIB0/6P3xKjPWD5P9xdxtvKQ/S7pJ0j+nsr1STO2MrzefaVnH75+B1RGxNFc2UI5fnzmh9KzW9cpSnrWWtDPwU+C0iHgS+B7wMmAisIrsFBrKi/nNEXEgcCTwCUlvbVC37TFKGgYcDfwkFQ2049dIvZhKiVXSWcA6YHYqWgWMiYgDgM8AP5a0awnx9fYzLeuzru70dqAcv35xQunZCmB0bnlvYGW7g5C0LVkymR0RPwOIiNURsT6yjjO/z6bLMqXEHBEr0+ujwDUpntXa1Dt0fqiCMmI8ErgjIlanOAfU8Ut6e7zaHmu68f8eYGq6DEO6lPT3ND+f7L7Efim+/GWxlsbXh8+0jOM3FHgvMCcX94A4fv3lhNKz24EJksanv3CnkHW/3zbpeusPgXsi4pu58vw9h//Gpm785wJTJG0naTwwgezGXitj3EnSLt3zZDdvF1F/qIK5wAfT00sHA0/EpnFuWmWzvwoH0vHL6e3xug44TNLwdHnnMDYfjK5Qko4gG9Du6MgNyS2pQ9KQNL8P2TFbnmJ8StLB6d/xB2k8XEV/4+vtZ1rG/+9DgL9ExMZLWQPl+PVb2U8FDIaJ7Ambe8n+ajirhP2/hew0dyFZd/0LUkw/Au5K5XOBUbk2Z6V4l9CGp0LInpK5M02Lu48TsAdwA7A0vY5I5QIuSDHeBVRaHN+OwN+B3XJlpR4/suS2CniB7C/Rk/tyvMjuZSxL04daHN8ysnsO3f8OL0x135c+9zuBO4CjctupkH2x3wecT+qho0Xx9fozbdX/71rxpfJLyIbcyNdt
+/FrxeSuV8zMrBC+5GVmZoVwQjEzs0I4oZiZWSGcUMzMrBBOKGZmVggnFNvqSNoj16vrI1W90w5rchsXS3p5D3U+IWlqQTHfknrE7Y5zTs+terX9FUo9B5v1lR8btq2apC8BT0fEf1SVi+z/x4ZSAqsi6RbglIhY0KLtrwBeHRGPt2L7tnXwGYpZImlfSYvSWBR3AKMkzZTUqWwcmi/m6t4iaaKkoZIelzRD0p2SbpW0Z6rzVUmn5erPkPSndKbxplS+k6SfprZXpH1N7EXMl0v6nqTfS7pX0pGpfAdJlyobR+OO7n7VUrznpfe5UNK/5TZ3mrLOCRdK2q/fB9S2Ok4oZpvbH/hhRBwQEX8lG5ukArwOOFTS/jXa7AbcFBGvA24l++V6LYqIScBnge7kdCrwSGo7g6wn6XrygzPNyJWPBt4GHAXMlLQd8Eng+Yh4DfCvwI/S5byPAy8FXhcRryXrrr3b6sg6J/wBWQeFZr0ytOwAzAaY+yLi9tzyCZJOJvu/8lKyhHN3VZtnI6K7G/f5ZF2T1/KzXJ1xaf4tZANVERF3SlrcILbj61zyuipdmlsi6WGyfqDeAnwjbXexpJXAvmT9SH0rItandfnxOvLxvQuzXnJCMdvcM90zkiYAnwImRcTjki4Htq/R5vnc/Hrq/796rkadIoZzrb4RWq9b9u791btxWis+s6b5kpdZfbsCTwFPatPoiEW7hWzkPiS9huwMqLeOTb0Q70d2+WspcDMwNW33lWTDSC8Drgc+nuvZdkS/34FZ4r9CzOq7g+zy1iKysdr/0IJ9fAe4TNLCtL9FwBN16s6R9GyaXx0R3QluGVkC2ROYHhHPS/oOcJGku8h6u/1gKr+I7JLYQknryAakurAF78u2Qn5s2KxEygZbGhoR/0iX2K4HJsSmcdt7an85cHVE/Gcr4zRrhs9QzMq1M3BDSiwCPtpsMjEbaHyGYmZmhfBNeTMzK4QTipmZFcIJxczMCuGEYmZmhXBCMTOzQvx/FiWRcLplcdUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fcb3cf53c88>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "### CFO then Equalizer ###\n",
    "\n",
    "import matplotlib.pyplot as plt   \n",
    "import numpy as np\n",
    "\n",
    "epochs = 2000\n",
    "batch_size = 1500\n",
    "\n",
    "#     # generate the data\n",
    "#     (channel_train, channel_test, preamble_train_orig, preamble_train_cfo, preamble_test_orig, \n",
    "# preamble_test_cfo, data_train_orig,data_train_cfo,data_test_orig,data_test_cfo) = data_generation(epochs)\n",
    "\n",
    "mc_losses = []\n",
    "\n",
    "# define equalizer NN class object\n",
    "net2 = CFO_Equalizer(preamble_length = 50, data_length = 100, channel_length = 2, \n",
    "                    batch_size=1500, learning_rate = 0.01)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    current_start = epoch*batch_size\n",
    "    channel_train_batch = channel_train[current_start:(current_start+batch_size),:,:]\n",
    "    data_train_cfo_batch = data_train_cfo[current_start:(current_start+batch_size),:,:]\n",
    "    data_train_orig_batch = data_train_orig[current_start:(current_start+batch_size),:,:]\n",
    "    preamble_train_orig_batch = preamble_train_orig[current_start:(current_start+batch_size),:,:]\n",
    "    preamble_train_cfo_batch = preamble_train_cfo[current_start:(current_start+batch_size),:,:]        \n",
    "\n",
    "    train_cost = net2.train_net(preamble_train_orig_batch, preamble_train_cfo_batch, data_train_orig_batch, \n",
    "                               data_train_cfo_batch, channel_train_batch)\n",
    "\n",
    "    if epoch % 100 == 0: \n",
    "\n",
    "\n",
    "        test_cost, test_est, test_omega = net2.test_net(preamble_test_orig, \n",
    "                                                                  preamble_test_cfo, data_test_orig, \n",
    "                                                                  data_test_cfo, channel_test)\n",
    "\n",
    "        mc_losses.append(test_cost)\n",
    "\n",
    "\n",
    "        plt.plot(epoch, np.log(test_cost), 'bo')\n",
    "        print('Epoch {}, Train Cost {}, Test Cost: {}'.format(epoch, train_cost,test_cost))\n",
    "\n",
    "\n",
    "plt.xlabel('Training Epoch')\n",
    "plt.ylabel('L2 Test Cost')\n",
    "plt.title('NN Equalizer: no noise')\n",
    "# plt.text(1000, .025, r'NN equalizer')\n",
    "# plt.text(0.5, .025, r'Zero Force equalizer')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Train Cost 0.9539691805839539, Test Cost: 0.6533547639846802\n",
      "Epoch 100, Train Cost 0.4940091073513031, Test Cost: 0.49371567368507385\n",
      "Epoch 200, Train Cost 0.4182162582874298, Test Cost: 0.4144271910190582\n",
      "Epoch 300, Train Cost 0.23925510048866272, Test Cost: 0.2369288057088852\n",
      "Epoch 400, Train Cost 0.08440049737691879, Test Cost: 0.08010869473218918\n",
      "Epoch 500, Train Cost 0.5377944111824036, Test Cost: 0.5361482501029968\n",
      "Epoch 600, Train Cost 0.49668964743614197, Test Cost: 0.49710613489151\n",
      "Epoch 700, Train Cost 0.4910082221031189, Test Cost: 0.4900324046611786\n",
      "Epoch 800, Train Cost 0.20533324778079987, Test Cost: 0.22462791204452515\n",
      "Epoch 900, Train Cost 0.4886311888694763, Test Cost: 0.4878615140914917\n",
      "Epoch 1000, Train Cost 0.1544889360666275, Test Cost: 0.14943507313728333\n",
      "Epoch 1100, Train Cost 0.03892027214169502, Test Cost: 0.03928423300385475\n",
      "Epoch 1200, Train Cost 0.5218047499656677, Test Cost: 0.5198233127593994\n",
      "Epoch 1300, Train Cost 0.4949823021888733, Test Cost: 0.49471786618232727\n",
      "Epoch 1400, Train Cost 0.4837802052497864, Test Cost: 0.48348864912986755\n",
      "Epoch 1500, Train Cost 0.10188066214323044, Test Cost: 0.10091014951467514\n",
      "Epoch 1600, Train Cost 0.15189984440803528, Test Cost: 0.14673158526420593\n",
      "Epoch 1700, Train Cost 0.04191119223833084, Test Cost: 0.04136177524924278\n",
      "Epoch 1800, Train Cost 0.03373783826828003, Test Cost: 0.03350622579455376\n",
      "Epoch 1900, Train Cost 0.030679849907755852, Test Cost: 0.030812356621026993\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEWCAYAAABxMXBSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XucHGWd7/HPN4TbhLsJyG1mUAIrq4Iwi/cLB0QSBbwsCo4rXjgjruBljx5xR0HFnBNlj7tnxRVHBUUGQY8iKHFJ8AiIijjBAOEeYhJiAkQBCQyCYX77Rz1jOkN3T9f0dFf3zPf9evWrqp5+qurX1TP963qe6qcUEZiZmdVqRtEBmJlZe3HiMDOzXJw4zMwsFycOMzPLxYnDzMxyceIwM7NcnDjMcpIUkvZP8+dJ+lTRMRVBUqekxyRtVXQs1lxOHNZwklZJekDSrJKyUyRdU7Ickm6VNKOk7HOSvllhm6+RNJI+uEofL23kaxkrIk6NiLObuc9WERFrImKHiHi66FisuZw4rFlmAh8ap85ewIk5trkufXCVPn418RBbg7/BW6tz4rBmOQf4qKRdqtT5AvAZSTPr3Zmk/SRdK2mjpCWSzpV0UXruNZLWjqm/StJRaf5wSb+S9Iik9WndbSrs55uSPpfmfzTm7GdE0rvSc3+T4nhI0l2S3jpmG1+RtEjS48ARNby+aySdLekX6TUuljS75PnjJN2WXsM1kp5XZVsh6VRJ90h6WNKXJSk9N0PSJyWtlvSgpAsl7Zye607rzkzL75K0MsXzO0m9Jft4j6Q70vavktQ13mu01uXEYc0yBFwDfLRKnR8AjwLvmoT9XQwsBWYDZwMn51j3aeAjad2XAkcC/zjeShFx7OiZD/D3wP3AT1MT3ZIU0+7AScB/SPrbktXfDiwAdgSul/R2SbeMs8u3A+9O29yGdGwlHQB8B/gwMAdYBPyoUvJL3gD8HXAw8Fbgdan8XelxBPAcYAfg3LErp9f478C8iNgReBmwLD33RuCfgTeneH6e4rM25cRhzXQmcLqkORWeD+BTwJmStq1he3ulb9Slj1mSOsk+BD8VEU9GxHXAj2oNMiKWRsQNEbEpIlYBXwVeXev66YP7QuBtEXEf2Yfyqoi4IG3zJuD7ZMll1OUR8YuIGImIP0fExRHxwnF2dUFE3B0RTwDfBQ5J5W8DroyIJRHxF+BfgO3JPswrWRgRj0TEGuBnJdvqBb4YESsj4jHgE8CJFc4KR4DnS9o+ItZHxG2p/H3A/46IOyJiE/C/gEN81tG+nDisaSJiOfBj4IwqdRYBa4C+Gja5LiJ2GfN4nKyv5OE0P2p1rXFKOkDSjyXdL+lRsg+62eOtl9bdGbicLGn9PBV3AS8uTXBkH8jPLln1vlrjK3F/yfww2dkAZK//r683IkbS9veud1tpfiawR+nK6Vi/DTgVWC/pSkl/k57uAv5vyWt/CNA48VgLc+KwZjsL+O9U/9D4JNAPdExwH+uBXUuv4gI6S+YfL9126owuPQv6CnAnMDcidiJrZtF4O01XhF0M/Cwivlry1H3AtWMS3A4R8f6SOpM5TPU6sg/r0bgE7Av8vt5tkR3HTcADYytGxFUR8VpgT7Lj97X01H3A+8a8/u0j4pcTiMdagBOHNVVErAAuBT5Ypc41wK3k65coXX81WZ/KZyRtI+kVwLElVe4GtpP0eklbkyWq0qaxHcn6Wh5L35pLP+CrWQDM4plXj/0YOEDSP0jaOj3+rlqHdZ2+C7xe0pHp9f0P4ElgIh/U3wE+ki422IHs7OvS1OT0V5L2SB3ys9K+HiPrKwI4D/jEaJ+OpJ0lnTChV2YtwYnDivBZsg/Yaj4J7DZOnb30zN9xvCU993bgxWTNImeR9TkAEBF/Iuvs/jrZt/DHgdKrrD6a1t9I9q350ppeVdbp/RLg4ZJ4eiNiI3A02aXG68iahT7Plslq
C5J6Jd1W6flqIuIu4B3Al4A/kCXNYyPiqQls7nzg28B1wO+APwOnl6k3gyxBrSM75q8mXVAQEZeRvd5LUtPfcmDeBGKxFiHfyMmmA0mfBvaPiHcUHYtZu/MZh5mZ5eLEYWZmubipyszMcvEZh5mZ5VL3mECtaPbs2dHd3V10GGZmbWPp0qV/iIhKozpsYUomju7uboaGhooOw8ysbUiqeXQFN1WZmVkuThxmZpaLE4eZmeXixGFmZrk4cZiZWS5OHMngIHR3w4wZ2XRwsOiIzMxa05S8HDevwUHo64Ph4Wx59epsGaC3t/J6ZmbTkc84gP7+zUlj1PBwVm5mZlty4gDWrMlXbmY2nTlxAJ2d+crNzKazQhKHpN0kLZF0T5ruWqHe05KWpccVjYpnwQLoGHN3646OrNzMzLZU1BnHGcBPI2Iu8NO0XM4TEXFIehzXqGB6e2FgALq6QMqmAwPuGDczK6eQ+3FIugt4TUSsl7QncE1EHFim3mMRsUPe7ff09IQHOTQzq52kpRHRU0vdos449oiI9QBpunuFettJGpJ0g6Q3VtugpL5Ud2jDhg2THa+ZmSUN+x2HpKuBZ5d5Ks9Frp0RsU7Sc4D/L+nWiLi3XMWIGAAGIDvjyB2wmZnVpGGJIyKOqvScpAck7VnSVPVghW2sS9OVkq4BXgSUTRxmZtYcRTVVXQGcnOZPBi4fW0HSrpK2TfOzgZcDtzctQjMzK6uoxLEQeK2ke4DXpmUk9Uj6eqrzPGBI0s3Az4CFEeHEYWZWsEISR0T8MSKOjIi5afpQKh+KiFPS/C8j4gURcXCafqOIWK02HiTSbPrwIIdWNw8SaTa9eMgRq5sHiWx/PmO0PHzGYXXzIJHtzWeMlpfPOKxuHiSyvfmM0fJy4pgk0/lUfzIGiZzOx69ok3HG6PdvmomIKfc47LDDopkuuiiioyMCNj86OrLy6eKiiyK6uiKkbJrntU/G8atn/5Oh6P3Xo6try2M/+ujqqm19//1PDcBQ1PgZW/iHfCMezU4c9f7jTXft/sFV9P7rVW/8k/H3386Jd6rIkzgKGR230Zo9Ou6MGdm/ylgSjIw0LYy2Ve/x6+7OOnTH6uqCVavqja719z8ZBgezPo01a7K+qQULau8Yr/f9G9s5D1lTp29t0Fx5Rsd14pgEU+GDo0j1Hr+iE3fR+y9ave+f/39aQzsMqz6l+A6C9an3+BV9VVfR+y9ave+fL+duP04ck8B3EKxPvcev6MRd9P6LVu/7N90Tb1uqtTOknR7N7hy34hXduVr0/ttZu19cMFXgznHfOtasndTTOW+TI08fh4ccMbPC9fY6UbQT93GYmVkuThxmZpaLE0eL8Fg/ZtYu3MfRAjystZm1E59xtAAPa21m7cSJowX4l7Nm1k6cOFqAfzlrZu3EiaMFTPchK8ysvThxtACPdWVm7aSQxCHpBEm3SRqRVPEn7pKOkXSXpBWSzmhmjM3W25sNIT0ykk2dNMysVRV1xrEceDNwXaUKkrYCvgzMAw4CTpJ0UHPCMzOzSgr5HUdE3AEgqVq1w4EVEbEy1b0EOB64veEBmplZRa3cx7E3cF/J8tpUVpakPklDkoY2bNjQ8ODMzKarhp1xSLoaeHaZp/oj4vJaNlGmrOIY8BExAAxANqx6TUGamVluDUscEXFUnZtYC+xbsrwPsK7ObZqZWZ1auanqN8BcSftJ2gY4Ebii4JjMzKa9oi7HfZOktcBLgSslXZXK95K0CCAiNgGnAVcBdwDfjYjbiojXzMw2K+qqqsuAy8qUrwPmlywvAhY1MTQzMxtHKzdVmZlZC3LiMDOzXJw4zMwsFycOMzPLxYnDzMxyceIwM7NcnDjMzCwXJw4zM8vFicPMzHJx4jAzs1ycOMzMLBcnDjMzy8WJw8zMcnHiMDOzXJw4zMwsFycOMzPLxYnDzMxyceIwM7NcnDjMzCwXJw4zM8vFiWOKGByE7m6YMSObDg4WHZGZTVUziw7A6jc4
CH19MDycLa9enS0D9PYWF5eZTU0+45gC+vs3J41Rw8NZuZnZZHPimALWrMlXbs/kpj6z2hWSOCSdIOk2SSOSeqrUWyXpVknLJA01M8Z20tmZr9y2NNrUt3o1RGxu6nPyMCuvqDOO5cCbgetqqHtERBwSERUTzHS3YAF0dGxZ1tGRldv43NRnlk8hiSMi7oiIu4rY91TU2wsDA9DVBVI2HRhwx3it3NRnls+4iUPSMxo8ypU1SACLJS2V1FetoqQ+SUOShjZs2NCk8FpHby+sWgUjI9nUSaN2buozy6eWM44f1li2BUlXS1pe5nF8jvheHhGHAvOAD0h6VaWKETEQET0R0TNnzpwcu7Dpzk19ZvlU/B2HpAOA5wE7Szqu5KmdgO3G23BEHFVvcBGxLk0flHQZcDi19YuY1Wz07Ky/P2ue6uzMkobP2szKq/YDwL8l68DeBTihpHwj8L5GBgUgaRYwIyI2pvmjgc82er82PfX2OlGY1api4oiIy4DLJL0iIq6fzJ1KehPwJWAOcKWkZRHxOkl7AV+PiPnAHmn/o3FeHBH/OZlxmJlZfrUMOfJ6SbcAw8CVwCHARyLi4onudDQplSlfB8xP8yuBgye6DzMza4xaOsfnRcSjwBuAB4HnAx9vaFRmZtayakkcW6fpfOA7EbGB7DJZMzObhmppqlokaTnwNNklsbOBJxsblpmZtapxzzgi4mPAfwMOi4i/AE+QXW1lZmbT0LhnHJJmAn8PvCpd4XQt8LUGx2VmZi2qlqaqLwOzgPPT8juAFwFVhwAxM7OpqZbE8ZKIKL0sdrGkmxsVkJmZtbZarqoakdQ9upDmRxoTjpmZtbpazjj+J3CdpLsBAfsD721oVGZm1rLGTRwRsUTSgWQDHgq4PSKeaHhkZmbWkqqNjnsSsFVEXJQSxU2p/BRJGyPi0mYFaWZmraNaH8fHgCvKlH+PrPnKzMymoWqJY2Yao2oLEfEnNg9DYmZm00y1xLGNpI6xhZJ2ALZtXEhmZtbKqiWO84HvSdpntCDNXwxc0OjAzMysNVW7kdMXJA0Dv07DjgD8BVgYEec2JTozM2s5VS/HTQniXEm7AIqIh5sTlpmZtapafjlORDzipGHWOIOD0N0NM2Zk08HBoiMyq6ymxGFmjTM4CH19sHo1RGTTvj4nj2Zy4s5n3MRR0r9RtczMJqa/H4aHtywbHs7KrfGcuPOr5YzjxhrLzGwC1qzJV26Ty4k7v2pDjuwO7AlsL+kFZONUAewEPOP3HWY2MZ2d2bfccuXWeE7c+VVrcno98B5gH7KbOY0mjo3Apxocl9m0sWBB1jRS+q23oyMrt8Zz4s6vYlNVRFwQEa8E3hsRr4qIV6bH/Ij4Xj07lXSOpDsl3SLpsnS5b7l6x0i6S9IKSWfUs0+zVtXbCwMD0NUFUjYdGMjKrfEWLMgSdSkn7upq6ePYXdJOAJLOk3SjpCPr3O8S4PkR8ULgbuATYytI2orsTGcecBBwkqSD6tyvWUvq7YVVq2BkJJs6aTSPE3d+tSSOvoh4VNLRZM1W7we+UM9OI2JxRGxKizek7Y51OLAiIlZGxFPAJcDx9ezXzKwcJ+58akkckabzgAsiYmmN69XqPcBPypTvDdxXsrw2lZUlqU/SkKShDRs2TGJ4ZmZWqpYEcLOkRcCxwE/S6LgxzjpIulrS8jKP40vq9AObgHJXTKtMWcX9RsRARPRERM+cOXPGfVFmZjYxtfyQ793AYWTNRsOSZlPDPccj4qhqz0s6GXgDcGRElEsIa4F9S5b3AdbVEK+ZmTXQuGccEfE08Byyvg2A7WtZrxpJxwAfB46LiOEK1X4DzJW0n6RtgBMpf0dCMzNrolqGHDkXOAJ4Ryp6HDivzv2eC+wILJG0TNJ5aV97pWYxUuf5acBVwB3AdyPitjr3a2ZmdaqlqeplEXGopN8CRMRD6QxgwiJi/wrl64D5JcuLgEX17MvMzCZXLU1Of5E0g9QxLelZwEhDozIz
s5ZVS+L4MvB9YI6kzwDXA59vaFRmZtayqg1yODMiNkXEhZKWAkeRXSJ7QkQsb1qEZmbWUqr1cdwIHAqQOqXdMW1mZlWbqsr9AM/MzKa5amcccyT9U6UnI+KLDYjHzMxaXLXEsRWwAz7zMDOzEtUSx/qI+GzTIjEzs7bgPg4zM8ulWuKo92ZNZmY2BVW7dexDzQzEzGyiBgehuxtmzMimg+Vu1GCTppaxqszMWtbgIPT1wXAaZ3v16mwZfCe/RpnMO/mZmTVdf//mpDFqeDgrt8aomDgk7SvpEkk/l/TPkrYuee6HzQnPzKy6NWvylVv9qp1xnA9cA5wO7Alcm0bGBehqcFxmZjXp7MxXbvWrljjmRMR5EbEsIk4H/gO4TtJzqeGe42ZmzbBgAXR0bFnW0ZGVW2NU6xzfWtJ2EfFngIi4SNL9ZHfkm9WU6MzMxjHaAd7fnzVPdXZmScMd441T7Yzj68CLSwsi4mrgBMDDqptZy+jthVWrYGQkmzppNFbFM46I+NcK5b+VdGXjQjIzs1Y20ctxK46aa2ZmU9tEE4fHsTIzm6Ymmjh8VZWZ2TRV7Z7jGymfIARs37CIzMyspVXrHN+xUTuVdA5wLPAUcC/w7oh4pEy9VcBG4GlgU0T0NComMzOrTVFjVS0Bnh8RLwTuBj5Rpe4REXGIk4aZWWsoJHFExOKI2JQWbwD2KSIO28zDUptZrVphdNz3AD+p8FwAiyUtldRXbSOS+iQNSRrasGHDpAc5lY0OS716NURsHpbaycPMylFEYy6QknQ18OwyT/VHxOWpTj/QA7w5ygQiaa+IWCdpd7LmrdMj4rrx9t3T0xNDQ0P1vYBppLs7SxZjdXVlv8I1s6lP0tJauwQadiOniDiq2vOSTgbeABxZLmmkbaxL0wclXQYcDoybOCwfD0ttZnkU0lQl6Rjg48BxETFcoc4sSTuOzgNH4zGyGsLDUptZHkX1cZwL7AgskbRM0nmQNU1JWpTq7AFcL+lm4Ebgyoj4z2LCndo8LLWZ5VHIPccjYv8K5euA+Wl+JXBwM+OarjwstZnlUUjisNbT2+tEYWa1aYXLcc3MrI04cZiZWS5OHGZmlosTh5mZ5eLEYWZmuThxmJlZLk4cZmaWixOHmZnl4sRhZma5OHGYmVkuThxmZpaLE4eZmeXixGFmZrk4cZiZWS5OHGZmlosTh5mZ5eLEYWZmuThxmJlZLk4cZmaWixOHmZnl4sRhZma5OHGYmVkuThxmZpZLYYlD0tmSbpG0TNJiSXtVqHeypHvS4+Rmx2lmZlsq8ozjnIh4YUQcAvwYOHNsBUm7AWcBLwYOB86StGtzwzQzs1KFJY6IeLRkcRYQZaq9DlgSEQ9FxMPAEuCYZsRnZmblzSxy55IWAO8E/gQcUabK3sB9JctrU1m5bfUBfQCdnZ2TG6iZmf1VQ884JF0taXmZx/EAEdEfEfsCg8Bp5TZRpqzcmQkRMRARPRHRM2fOnMl7EWZmtoWGnnFExFE1Vr0YuJKsP6PUWuA1Jcv7ANfUHZiZmU1YkVdVzS1ZPA64s0y1q4CjJe2aOsWPTmVmZlaQIvs4Fko6EBgBVgOnAkjqAU6NiFMi4iFJZwO/Set8NiIeKiZcMzMDUETZLoO21tPTE0NDQ0WHYWbWNiQtjYieWur6l+NmZpaLE4eZmeXixGFmZrk4cZiZ1WlwELq7YcaMbDo4WHREjVXoL8fNzNrd4CD09cHwcLa8enW2DNDbW1xcjeQzDjOzOvT3b04ao4aHs/KpyonDbAqYbk0lrWTNmnzlU4ETh1mbG20qWb0aIjY3lTh5NEelMVWn8lirThxmbW46NpW0kgULoKNjy7KOjqx8qnLiMGtz07GppJX09sLAAHR1gZRNBwambsc4+Koqs7bX2Zk1T5Urt+bo7Z3aiWIsn3GYtbnp2FRixXLiMGtz07GpxIrlpiqzKWC6NZVYsXzGYWZmuThxmJlZLk4cZmYFa7df/ruPw8ysQO04
SKLPOMzMCtSOv/x34jAzK1A7/vLficPMrEDtOEiiE4eZWYHa8Zf/ThxmZgWajF/+N/uqrEKuqpJ0NnA8MAI8CLwrItaVqfc0cGtaXBMRxzUvSjOz5qjnl/9FXJWliGjMlqvtVNopIh5N8x8EDoqIU8vUeywidsi7/Z6enhgaGpqESM3MWlt3d/nRkbu6YNWq2rcjaWlE9NRSt5CmqtGkkcwCmp+9zMymgCKuyiqsj0PSAkn3Ab3AmRWqbSdpSNINkt44zvb6Ut2hDRs2THq8ZmatqIirshqWOCRdLWl5mcfxABHRHxH7AoPAaRU205lOnd4O/Juk51baX0QMRERPRPTMmTNn0l+PmVkrKuKqrIZ1jkfEUTVWvRi4EjirzDbWpelKSdcALwLunawYzcza3WgHeH9/1jzV2ZkljUYOV1LUVVVzI+KetHgccGeZOrsCwxHxpKTZwMuBLzQxTDOzttDs+7EUNcjhQkkHkl2Ouxo4FUBSD3BqRJwCPA/4qqQRsia1hRFxe0HxmplZUkjiiIi3VCgfAk5J878EXtDMuMzMbHz+5biZmeXixGFmZrk4cZiZWS6FDDnSaJI2kHW6T8Rs4A+TGM5kc3z1cXz1cXz1aeX4uiKiph/BTcnEUQ9JQ7WO11IEx1cfx1cfx1efVo+vVm6qMjOzXJw4zMwsFyeOZxooOoBxOL76OL76OL76tHp8NXEfh5mZ5eIzDjMzy8WJw8zMcnHiSCQdI+kuSSsknVFQDPtK+pmkOyTdJulDqfzTkn4vaVl6zC9Z5xMp5rskva4JMa6SdGuKYyiV7SZpiaR70nTXVC5J/57iu0XSoQ2O7cCSY7RM0qOSPlz08ZN0vqQHJS0vKct9zCSdnOrfI+nkBsZ2jqQ70/4vk7RLKu+W9ETJcTyvZJ3D0t/FihS/JiO+KjHmfk8b9T9eIb5LS2JbJWlZKi/kGE66iJj2D2Arsvt8PAfYBriZ7D7ozY5jT+DQNL8jcDdwEPBp4KNl6h+UYt0W2C+9hq0aHOMqYPaYsi8AZ6T5M4DPp/n5wE8AAS8Bft3k9/R+oKvo4we8CjgUWD7RYwbsBqxM013T/K4Niu1oYGaa/3xJbN2l9cZs50bgpSnunwDzGnz8cr2njfwfLxffmOf/D3Bmkcdwsh8+48gcDqyIiJUR8RRwCXB8s4OIiPURcVOa3wjcAexdZZXjgUsi4smI+B2wguy1NNvxwLfS/LeAN5aUXxiZG4BdJO3ZpJiOBO6NiGojCDTl+EXEdcBDZfad55i9DlgSEQ9FxMPAEuCYRsQWEYsjYlNavAHYp9o2Unw7RcSvIvsEvLDk9dStwvGrpNJ72rD/8WrxpbOGtwLfqbaNRh/DyebEkdkbuK9keS3VP7AbTlI32R0Pf52KTktNB+ePNmtQTNwBLJa0VFJfKtsjItZDlvyA3QuMb9SJbPnP2irHb1TeY1ZUrO8h+/Y7aj9Jv5V0raRXprK9UzzNji3Pe1rU8Xsl8EBsvnEdtNYxnBAnjky5tsTCrlOWtAPwfeDDEfEo8BXgucAhwHqyU18oJu6XR8ShwDzgA5JeVaVuIcdV0jZkd5b8XipqpeM3nkoxNT1WSf3AJmAwFa0HOiPiRcA/ARdL2qmI2Mj/nhb1Xp/Ell9gWukYTpgTR2YtsG/J8j7AuiICkbQ1WdIYjIgfAETEAxHxdESMAF9jc3NK0+OOzfeBfxC4LMXywGgTVJo+WFR8yTzgpoh4IMXaMsevRN5j1tRYU+f7G4De1HRCav75Y5pfStZncECKrbQ5qxl/h3nf06a/15JmAm8GLi2Ju2WOYT2cODK/AeZK2i99Wz0RuKLZQaT20G8Ad0TEF0vKS/sF3gSMXr1xBXCipG0l7QfMJetga1R8syTtODpP1om6PMUxepXPycDlJfG9M10p9BLgT6PNMw22xbe8Vjl+Y+Q9ZlcBR0vaNTXLHJ3KJp2kY4CPA8dFxHBJ+RxJW6X555Adr5Upvo2SXpL+ht9Z
8noaYgLvaRH/40cBd0bEX5ugWukY1qXo3vlWeZBdzXI32TeA/oJieAXZ6ektwLL0mA98G7g1lV8B7FmyTn+K+S4afBUG2RUpN6fHbaPHCXgW8FPgnjTdLZUL+HKK71agpwnHsAP4I7BzSVmhx48sia0H/kL2zfK9EzlmZP0NK9Lj3Q2MbQVZf8Do3+B5qe5b0vt+M3ATcGzJdnrIPrzvBc4ljUrRwBhzv6eN+h8vF18q/yZw6pi6hRzDyX54yBEzM8vFTVVmZpaLE4eZmeXixGFmZrk4cZiZWS5OHGZmlosTh01Zkp5VMgrp/WNGU92mxm1cIOnAcep8QFLvJMV8fRrBdTTOS8dfK9f21yqNdms2Ub4c16YFSZ8GHouIfxlTLrL/g5FCAhtD0vXAaRGxrEHbXws8PyIeacT2bXrwGYdNO5L2l7Q83QvhJmBPSQOShpTdB+XMkrrXSzpE0kxJj0haKOlmSb+StHuq8zlJHy6pv1DSjenM4WWpfJak76d1v5P2dUiOmC+S9BVJP5d0t6R5qXx7Sd9Sdh+Hm0bHDkvx/mt6nbdI+seSzX1Y2SB7t0g6oO4DatOOE4dNVwcB34iIF0XE78nujdEDHAy8VtJBZdbZGbg2Ig4GfkX2S+5yFBGHAx8DRpPQ6cD9ad2FZCMfV1J6E6CFJeX7Aq8GjgUGJG0LfBB4KiJeAPwD8O3UDPd+YC/g4Ih4Idkw4qMeiGyQva+TDbRnlsvMogMwK8i9EfGbkuWTJL2X7H9iL7LEcvuYdZ6IiNEhxpeSDZldzg9K6nSn+VeQ3RSJiLhZ0m1VYntbhaaq76Ymtbsk3Uc2ztErgHPSdm+TtA7Yn2ycpH+LiKfTc6X3iyiNbz5mOTlx2HT1+OiMpLnAh4DDI+IRSRcB25VZ56mS+aep/P/zZJk6k3Eb0LEdkpWGCx/dX6UOzHLxmdXMTVVmsBOwEXhUm++2N9muJ7sTHJJeQHZGk9cJadTcA8iare4BrgN603afR3b74RXAYuD9JSOx7lb3KzBL/G3DLOsgv51sZNKVwC8asI8vARdKuiXtbznwpwp1L5X0RJp/ICJGE9kKskSxO9AXEU9J+hLwVUm3ko3O+s5U/lWypqxbJG0iu/HReQ14XTYN+XJLM4CUAAAAXklEQVRcsyZQdlOfmRHx59Q0thiYG5vv7T3e+hcB/y8iftjIOM1q4TMOs+bYAfhpSiAC3ldr0jBrNT7jMDOzXNw5bmZmuThxmJlZLk4cZmaWixOHmZnl4sRhZma5/Bcb1Jaffax+OwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fc9ecbf3518>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "### Equalizer then CFO ###\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt   \n",
    "import numpy as np\n",
    "\n",
    "epochs = 2000\n",
    "batch_size = 1500\n",
    "\n",
    "#     # generate the data\n",
    "#     (channel_train, channel_test, preamble_train_orig, preamble_train_cfo, preamble_test_orig, \n",
    "# preamble_test_cfo, data_train_orig,data_train_cfo,data_test_orig,data_test_cfo) = data_generation(epochs)\n",
    "\n",
    "mc_losses = []\n",
    "\n",
    "# define equalizer NN class object\n",
    "net = Equalizer_then_CFO(preamble_length = 50, data_length = 100, channel_length = 2, \n",
    "                    batch_size=1500, learning_rate = 0.001)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    current_start = epoch*batch_size\n",
    "    channel_train_batch = channel_train[current_start:(current_start+batch_size),:,:]\n",
    "    data_train_cfo_batch = data_train_cfo[current_start:(current_start+batch_size),:,:]\n",
    "    data_train_orig_batch = data_train_orig[current_start:(current_start+batch_size),:,:]\n",
    "    preamble_train_orig_batch = preamble_train_orig[current_start:(current_start+batch_size),:,:]\n",
    "    preamble_train_cfo_batch = preamble_train_cfo[current_start:(current_start+batch_size),:,:]        \n",
    "\n",
    "    train_cost = net.train_net(preamble_train_orig_batch, preamble_train_cfo_batch, data_train_orig_batch, \n",
    "                               data_train_cfo_batch, channel_train_batch)\n",
    "\n",
    "    if epoch % 100 == 0: \n",
    "\n",
    "\n",
    "        test_cost, test_est, test_omega, test_corr = net.test_net(preamble_test_orig, \n",
    "                                                                  preamble_test_cfo, data_test_orig, \n",
    "                                                                  data_test_cfo, channel_test)\n",
    "\n",
    "        mc_losses.append(test_cost)\n",
    "\n",
    "\n",
    "        plt.plot(epoch, np.log(test_cost), 'bo')\n",
    "        print('Epoch {}, Train Cost {}, Test Cost: {}'.format(epoch, train_cost,test_cost))\n",
    "\n",
    "\n",
    "plt.xlabel('Training Epoch')\n",
    "plt.ylabel('L2 Test Cost')\n",
    "plt.title('NN Equalizer: no noise')\n",
    "# plt.text(1000, .025, r'NN equalizer')\n",
    "# plt.text(0.5, .025, r'Zero Force equalizer')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.16234691 -0.03970928]\n",
      " [-0.19225084 -0.17413898]\n",
      " [ 0.09435235 -0.16037185]\n",
      " [ 0.03150128  0.02534925]\n",
      " [ 0.0396372  -0.08908354]\n",
      " [ 0.26403277 -0.10009514]\n",
      " [ 0.11368872  0.07741069]\n",
      " [-0.21185086 -0.10727884]\n",
      " [-0.15390115 -0.12785102]\n",
      " [ 0.12514209  0.27123858]\n",
      " [ 0.03433518  0.11138798]\n",
      " [-0.02079661 -0.41602449]\n",
      " [-0.12422998 -0.09372808]\n",
      " [-0.18149359  0.15800654]\n",
      " [-0.12807034 -0.05179549]\n",
      " [-0.0532961   0.07174487]\n",
      " [-0.23837666  0.06719859]\n",
      " [-0.02387224  0.09061922]\n",
      " [ 0.06423442  0.28427605]\n",
      " [ 0.12809862 -0.20461286]\n",
      " [-0.09872968 -0.30257179]\n",
      " [-0.11483188  0.05212306]\n",
      " [ 0.02188127 -0.02777295]\n",
      " [-0.07905428 -0.19221901]\n",
      " [-0.06869019 -0.02545906]\n",
      " [-0.13699384  0.08126296]\n",
      " [ 0.25910215  0.26435666]\n",
      " [-0.19925035 -0.07228382]\n",
      " [-0.35770605  0.09340321]\n",
      " [ 0.1479869   0.04556532]\n",
      " [ 0.26893572  0.04971458]\n",
      " [ 0.21770821 -0.03788875]\n",
      " [ 0.10874559 -0.08441781]\n",
      " [-0.13635941  0.09711163]\n",
      " [ 0.46304437 -0.22444283]\n",
      " [ 0.15404083  0.01289208]\n",
      " [-0.0572278   0.15480952]\n",
      " [ 0.05287291  0.00403894]\n",
      " [ 0.10617195  0.30916558]\n",
      " [ 0.25115283 -0.17489218]\n",
      " [ 0.08022864 -0.20382984]\n",
      " [-0.08946497  0.10443734]\n",
      " [ 0.19423999 -0.16095875]\n",
      " [ 0.28807579  0.02707393]\n",
      " [ 0.51565693 -0.48704176]\n",
      " [ 0.024835   -0.28290169]\n",
      " [ 0.02395494 -0.08371098]\n",
      " [ 0.12882764 -0.26464311]\n",
      " [-0.44002215 -0.02136834]\n",
      " [-0.32562666 -0.05823125]]\n",
      "[[ 8.70706253e-04]\n",
      " [-2.08767864e-03]\n",
      " [ 3.98835692e-03]\n",
      " [-1.53487401e-03]\n",
      " [-1.30424724e-03]\n",
      " [-3.36907828e-04]\n",
      " [-1.75556069e-03]\n",
      " [ 1.82798600e-03]\n",
      " [-6.84431790e-04]\n",
      " [ 1.29040862e-03]\n",
      " [-2.92847746e-03]\n",
      " [ 1.07195567e-03]\n",
      " [ 3.15297222e-03]\n",
      " [ 1.63372877e-03]\n",
      " [-1.58085195e-04]\n",
      " [-1.09488368e-03]\n",
      " [-4.79475050e-03]\n",
      " [ 4.66632685e-03]\n",
      " [ 4.91197630e-04]\n",
      " [-1.36295413e-03]\n",
      " [ 8.08767512e-04]\n",
      " [-1.30800533e-03]\n",
      " [-2.63578526e-03]\n",
      " [-2.74557066e-05]\n",
      " [ 1.71585799e-04]\n",
      " [-1.91792887e-03]\n",
      " [-3.00278210e-03]\n",
      " [ 2.20945097e-03]\n",
      " [-1.68286110e-03]\n",
      " [-4.83139433e-05]\n",
      " [-1.02539297e-04]\n",
      " [ 4.40826854e-03]\n",
      " [ 1.87148167e-03]\n",
      " [ 3.14971199e-03]\n",
      " [ 5.49578401e-04]\n",
      " [ 4.34000119e-04]\n",
      " [ 4.51547723e-03]\n",
      " [-2.74884936e-03]\n",
      " [ 4.10348435e-03]\n",
      " [-8.61319954e-04]\n",
      " [ 1.79974112e-03]\n",
      " [ 7.48622374e-04]\n",
      " [-1.57100949e-03]\n",
      " [ 1.28178602e-03]\n",
      " [ 1.65170383e-03]\n",
      " [-1.22516577e-03]\n",
      " [ 4.30768128e-04]\n",
      " [ 1.02460844e-03]\n",
      " [-1.45917901e-03]\n",
      " [ 3.65326947e-03]\n",
      " [-6.89508713e-04]\n",
      " [-3.11101818e-03]\n",
      " [-1.66675674e-03]\n",
      " [ 1.71572471e-03]\n",
      " [-1.78502462e-05]\n",
      " [ 3.16935190e-03]\n",
      " [ 1.42577054e-03]\n",
      " [-1.62822448e-04]\n",
      " [ 3.32524742e-03]\n",
      " [ 1.08983034e-03]\n",
      " [ 1.10759906e-03]\n",
      " [ 4.26032159e-04]\n",
      " [ 2.01130865e-03]\n",
      " [ 4.86406792e-05]\n",
      " [ 1.10419358e-03]\n",
      " [ 5.31857394e-03]\n",
      " [ 1.01532870e-03]\n",
      " [ 9.26482681e-04]\n",
      " [ 4.07328563e-03]\n",
      " [ 1.19197184e-03]\n",
      " [ 8.85305395e-04]\n",
      " [ 5.43692517e-04]\n",
      " [ 1.13627445e-03]\n",
      " [-6.97269793e-04]\n",
      " [-1.23938084e-03]\n",
      " [ 2.35947731e-04]\n",
      " [ 5.30514414e-03]\n",
      " [ 1.80585904e-03]\n",
      " [ 2.74062866e-03]\n",
      " [-1.21287187e-03]\n",
      " [-1.61524962e-03]\n",
      " [ 1.86839027e-03]\n",
      " [-4.72058153e-04]\n",
      " [ 1.54962565e-03]\n",
      " [-3.93159218e-03]\n",
      " [-3.41339064e-03]\n",
      " [-7.73043777e-04]\n",
      " [ 8.98975380e-04]\n",
      " [ 1.60523858e-03]\n",
      " [ 8.09461893e-04]\n",
      " [ 8.88733861e-04]\n",
      " [-2.03706255e-03]\n",
      " [ 4.17429780e-04]\n",
      " [ 7.86293786e-04]\n",
      " [ 2.03520269e-03]\n",
      " [ 3.33748089e-05]\n",
      " [ 1.30299130e-04]\n",
      " [ 6.84678390e-04]\n",
      " [ 1.43547470e-04]\n",
      " [-2.62185243e-05]\n",
      " [ 4.20015211e-03]\n",
      " [-8.14933242e-06]\n",
      " [ 1.20065207e-03]\n",
      " [-6.88031606e-04]\n",
      " [-9.77563897e-06]\n",
      " [-7.32635153e-04]\n",
      " [ 1.70311239e-03]\n",
      " [-5.19336456e-04]\n",
      " [ 1.15500445e-03]\n",
      " [ 2.05064065e-03]\n",
      " [ 1.03185705e-03]\n",
      " [-2.96883284e-04]\n",
      " [ 1.93541740e-03]\n",
      " [-8.24239506e-04]\n",
      " [ 2.23221765e-03]\n",
      " [ 2.69795043e-03]\n",
      " [-1.56873732e-03]\n",
      " [ 2.96107487e-03]\n",
      " [ 1.09844401e-04]\n",
      " [ 2.76010622e-03]\n",
      " [-1.37568341e-03]\n",
      " [ 1.47128856e-03]\n",
      " [-5.22756806e-04]\n",
      " [ 2.02568302e-03]\n",
      " [-8.98572801e-04]\n",
      " [ 3.35178582e-03]\n",
      " [-1.67478536e-03]\n",
      " [ 3.64977071e-03]\n",
      " [ 2.91387246e-04]\n",
      " [-4.87882171e-04]\n",
      " [ 9.86637680e-04]\n",
      " [-1.05125921e-03]\n",
      " [-2.50463025e-03]\n",
      " [ 1.85242282e-03]\n",
      " [ 2.12218832e-03]\n",
      " [-1.60856932e-03]\n",
      " [ 1.33420996e-03]\n",
      " [-2.57728586e-03]\n",
      " [ 1.36118081e-03]\n",
      " [ 1.77601120e-04]\n",
      " [-1.04667858e-03]\n",
      " [ 2.22200297e-03]\n",
      " [-7.50578064e-04]\n",
      " [-2.65677180e-04]\n",
      " [-1.61209613e-04]\n",
      " [ 3.96747217e-04]\n",
      " [-2.92213385e-03]\n",
      " [ 2.85602270e-03]\n",
      " [ 5.52380587e-04]\n",
      " [ 6.20348544e-04]]\n",
      "[0.01028296]\n",
      "[0.00953434]\n"
     ]
    }
   ],
   "source": [
    "k=41\n",
    "print(test_est[0,:,:]-preamble_test_orig[0,:,:])\n",
    "\n",
    "print(test_omega-omega_test)\n",
    "\n",
    "print(test_omega[k])\n",
    "print(omega_test[k])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
